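# 1. Input image preprocessing: bring every image to a fixed size (resize or zero-pad/crop); rotation is not implemented.
# 2. Ground-truth transformation: keypoint coordinates plus per-keypoint heatmaps, padded/cropped to (channels, 512, 512).
# 3. Dataset classes to be wrapped in a DataLoader, supplying model inputs and the targets used to compute the loss.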
import os
import numpy as np
import pandas as pd
import torch
from skimage import io, transform  # 用于图像的IO和变换
from torch.utils.data import Dataset, DataLoader
from scripts.kp.GassionHeatMap import generate_hmap_mask
from torchvision import transforms

# Default compute device: the first CUDA GPU when one is available, else the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


class KeyPointsScaleDataSet(Dataset):
    """Clothing keypoint dataset that resizes every image to 64x64.

    Each item is ``(image, landmarks, original_size, scale)``:

    * ``image`` -- float32 CHW tensor, pixel values divided by 255 after
      ``transforms_img`` (or a plain HWC -> CHW conversion when it is None);
    * ``landmarks`` -- ``(num, 3)`` integer tensor parsed from the CSV; column 0
      is rescaled by ``64 / width`` and column 1 by ``64 / height``; column 2 is
      passed through unchanged (presumably a visibility flag -- confirm);
    * ``original_size`` -- tensor ``[w, h]`` of the source image;
    * ``scale`` -- the value in the CSV's last column.
    """

    def __init__(self, csv_file, root_dir, num, transforms_img=None):
        """
        Initialise the dataset.

        :param csv_file: CSV whose first column is the image path, third to
            ``num + 2``-th columns are ``"a_b_c"`` keypoint strings and last
            column is the scale value
        :param root_dir: directory that the image paths are relative to
        :param num: number of keypoints per image
        :param transforms_img: (callable, optional) transform applied to the
            resized HWC ndarray; expected to return a CHW tensor
        """
        self.data_info = self.get_file_info(csv_file, num)
        self.root_dir = root_dir
        self.transform_img = transforms_img

    def __len__(self):
        return len(self.data_info[0])

    def __getitem__(self, idx):
        H, W = 64, 64  # integer target size; skimage expects an integer shape
        img_path = os.path.join(self.root_dir, self.data_info[0][idx])
        image = io.imread(img_path)
        h, w = image.shape[:2]  # also works for single-channel (2-D) images
        # Explicit copy so the rescaling below never writes into the cached labels.
        landmarks = np.array(self.data_info[1][idx], copy=True)
        # preserve_range keeps the original 0-255 range, so the /255 below maps
        # to [0, 1]. skimage's default already rescales uint8 to [0, 1], which
        # combined with /255 squashed images into [0, 1/255].
        image = self.change_img_size(image, H, W, preserve_range=True)
        # NOTE: landmarks is an integer array, so rescaled coordinates are
        # truncated (dtype kept for backward compatibility).
        landmarks[:, 0] = landmarks[:, 0] * W / w
        landmarks[:, 1] = landmarks[:, 1] * H / h
        image = self._to_tensor(image) / 255.0
        scale = self.data_info[2][idx]
        return image.float(), torch.tensor(landmarks), torch.tensor([w, h]), torch.tensor(scale)

    def _to_tensor(self, image):
        """Apply ``transform_img``; without one, convert HWC (or HW) to a CHW tensor."""
        if self.transform_img is not None:
            return self.transform_img(image)
        chw = np.atleast_3d(image).transpose(2, 0, 1)
        return torch.from_numpy(np.ascontiguousarray(chw))

    @staticmethod
    def get_file_info(file_path, num):
        """Parse the label CSV.

        :returns: ``(img_list, landmarks, scale_list)`` where ``landmarks`` is an
            integer array of shape ``(rows, num, 3)`` built from ``"a_b_c"`` cells.
        """
        file_info = pd.read_csv(file_path)
        img_list = file_info.iloc[:, 0]
        cells = file_info.iloc[:, 2:num + 2].values
        # One flat list of 3 * num ints per row, then reshaped per keypoint.
        rows = [[int(v) for cell in row for v in cell.split('_')] for row in cells]
        landmarks = np.array(rows).reshape((-1, num, 3))
        scale_list = file_info.iloc[:, -1]
        return img_list, landmarks, scale_list

    @staticmethod
    def change_img_size(image, h, w, preserve_range=False):
        """Resize ``image`` to ``(h, w)`` with skimage.

        ``preserve_range=False`` (the default) lets skimage convert to float in
        [0, 1]; ``True`` keeps the original value range.
        """
        return transform.resize(image, (h, w), preserve_range=preserve_range)


class KeyPointsDataSet(Dataset):
    """Clothing keypoint dataset that zero-pads every image to a 512x512 canvas.

    Each item is ``(image, landmarks, heatmap)``:

    * ``image`` -- float32 CHW tensor: the padded image after ``transform_img``
      (or a plain HWC -> CHW conversion when it is None), divided by 255;
    * ``landmarks`` -- ``(24, 3)`` integer tensor parsed from the CSV, in the
      original pixel coordinates (padding/cropping at the top-left corner does
      not move pixels, so no rescaling is needed);
    * ``heatmap`` -- output of ``generate_hmap_mask`` padded/cropped to
      ``(channels, 512, 512)`` (presumably one channel per keypoint -- confirm).
    """

    def __init__(self, csv_file, root_dir, transform_img=None, transform_heat=None):
        """
        Initialise the dataset.

        :param csv_file: CSV whose first column is the image path and whose
            third to 26th columns are ``"a_b_c"`` keypoint strings
        :param root_dir: directory that the image paths are relative to
        :param transform_img: (callable, optional) transform applied to the padded
            HWC ndarray; expected to return a CHW tensor
        :param transform_heat: (callable, optional) transform applied to the padded
            heatmap ndarray; expected to return a tensor
        """
        self.data_info = self.get_file_info(csv_file)
        self.root_dir = root_dir
        self.transform_img = transform_img
        self.transform_heat = transform_heat

    def __len__(self):
        return len(self.data_info[0])

    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, self.data_info[0][idx])
        image = io.imread(img_path)
        o_image_size = image.shape[0:2]
        # Explicit copy so the cached label array is never modified via a sample.
        landmarks = np.array(self.data_info[1][idx], copy=True)
        heatmap = self.get_htmap(o_image_size, landmarks)

        # Image and heatmap are zero-padded / cropped from the top-left corner
        # without resampling, so the landmark coordinates stay valid as they are
        # (the previous 512/h, 512/w rescaling moved them away from the heatmap).
        image = self.change_img_size(image)
        heatmap = self.change_heat_size(heatmap)

        image = self._image_to_tensor(image) / 255
        if self.transform_heat:
            heatmap = self.transform_heat(heatmap)
        else:
            heatmap = torch.from_numpy(heatmap)
        return image.float(), torch.tensor(landmarks), heatmap

    def _image_to_tensor(self, image):
        """Apply ``transform_img``; without one, convert the HWC ndarray to CHW."""
        if self.transform_img:
            return self.transform_img(image)
        return torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1)))

    @staticmethod
    def get_file_info(file_path, num=24):
        """Parse the label CSV.

        :param num: number of keypoint columns (24 by default)
        :returns: ``(img_list, landmarks)`` where ``landmarks`` is an integer
            array of shape ``(rows, num, 3)`` built from ``"a_b_c"`` cells.
        """
        file_info = pd.read_csv(file_path)
        img_list = file_info.iloc[:, 0]
        cells = file_info.iloc[:, 2:num + 2].values
        # One flat list of 3 * num ints per row, then reshaped per keypoint.
        rows = [[int(v) for cell in row for v in cell.split('_')] for row in cells]
        landmarks = np.array(rows).reshape((-1, num, 3))
        return img_list, landmarks

    @staticmethod
    def get_htmap(image_size, landmarks):
        """Build the keypoint heatmaps for an image of ``image_size`` via ``generate_hmap_mask``."""
        return generate_hmap_mask(image_size, landmarks)

    @staticmethod
    def change_heat_size(image, size=512):
        """Zero-pad / crop a ``(C, H, W)`` array to ``(C, size, size)``, anchored top-left."""
        chn, h, w = image.shape
        canvas = np.zeros((chn, size, size))
        hh, ww = min(h, size), min(w, size)
        canvas[:, :hh, :ww] = image[:, :hh, :ww]
        return canvas

    @staticmethod
    def change_img_size(image, size=512):
        """Zero-pad / crop an ``(H, W, C)`` image to ``(size, size, C)``, anchored top-left.

        A 2-D (grayscale) image is treated as having a single channel.
        """
        if image.ndim == 2:
            image = image[:, :, np.newaxis]
        h, w, chn = image.shape
        canvas = np.zeros((size, size, chn))
        hh, ww = min(h, size), min(w, size)
        canvas[:hh, :ww, :] = image[:hh, :ww, :]
        return canvas


class DataSet_Test(KeyPointsDataSet):
    """Unlabelled test-set dataset: each item is only the padded image tensor."""

    def __init__(self, csv_file, root_dir, transform_img):
        super().__init__(csv_file, root_dir, transform_img)

    def __len__(self):
        # Here data_info is a single Series of file names (see get_file_info), so
        # the inherited len(self.data_info[0]) would return the length of the
        # first file name instead of the number of rows.
        return len(self.data_info)

    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, self.data_info[idx])
        image = self.change_img_size(io.imread(img_path))
        if self.transform_img:
            image = self.transform_img(image)
        else:
            # No transform given: convert the padded HWC ndarray to CHW directly.
            image = torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1)))
        return (image / 255).float()

    @staticmethod
    def get_file_info(file_path):
        """Return the first CSV column (the image paths) as a pandas Series."""
        file_info = pd.read_csv(file_path)
        return file_info.iloc[:, 0]


class ToTensor:
    """Wrap a numpy ndarray sample as a torch Tensor (no scaling, shared memory)."""

    def __call__(self, sample):
        tensor = torch.from_numpy(sample)
        return tensor


# Image pipeline: HWC ndarray -> CHW tensor. torchvision's ToTensor only rescales
# to [0, 1] for uint8 / PIL input; float arrays keep their range.
# For ImageNet-style standardisation, append:
#   transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transform_img = transforms.Compose(
    [transforms.ToTensor()]
)

# Heatmap pipeline: wrap the ndarray as a tensor, values unchanged.
transform_heat = transforms.Compose(
    [ToTensor()]
)

if __name__ == "__main__":
    # Smoke test: load a few batches from either dataset and print their shapes.
    import argparse

    parser = argparse.ArgumentParser(description="Smoke-test the keypoint datasets.")
    parser.add_argument("--csv", default=r"E:\Datasets\Fashion\Fashion AI-keypoints\test\test.csv",
                        help="CSV file listing the images (and keypoints in train mode).")
    parser.add_argument("--root", default=r"E:\Datasets\Fashion\Fashion AI-keypoints\test",
                        help="Directory the image paths in the CSV are relative to.")
    parser.add_argument("--mode", choices=("test", "train"), default="test",
                        help="'test' loads images only; 'train' also loads landmarks and heatmaps.")
    parser.add_argument("--batches", type=int, default=2, help="Number of batches to load.")
    args = parser.parse_args()

    if args.mode == "train":
        dataset = KeyPointsDataSet(csv_file=args.csv, root_dir=args.root,
                                   transform_img=transform_img, transform_heat=transform_heat)
    else:
        dataset = DataSet_Test(csv_file=args.csv, root_dir=args.root, transform_img=transform_img)

    loader = DataLoader(dataset=dataset, batch_size=4)
    for i_batch, data in enumerate(loader):
        if args.mode == "train":
            img, landmarks, hmap = data
            print(i_batch, tuple(img.shape), tuple(landmarks.shape), tuple(hmap.shape))
        else:
            print(i_batch, tuple(data.shape))
        if i_batch + 1 >= args.batches:
            break
